

import numpy as np
import torch
from torch import nn

from global_vars import BOUNDS, PIXEL_SIZE, IN_SHAPE
from models.transporter_blocks import (
    Attention, TwoStreamAttention, TwoStreamAttentionLat,
    TwoStreamAttentionLangFusion, TwoStreamAttentionLangFusionLat,
    OneStreamAttentionLangFusion,
    Transport, TwoStreamTransport, TwoStreamTransportLat,
    TwoStreamTransportLangFusion, TwoStreamTransportLangFusionLat,
    OneStreamTransportLangFusion
)
from models.core.attention_image_goal import AttentionImageGoal
from models.core.transport_image_goal import TransportImageGoal
import utils.transporter_utils as utils

import ipdb
st = ipdb.set_trace


class TransporterAgent(nn.Module):
    """Base class to implement Transporter in PyTorch.

    Holds a pick network and a place network (``pick_net`` / ``place_net``)
    and converts their score maps into pick and place end-effector poses.
    ``_build_model`` is called at the end of ``__init__``; the base version
    only sets both networks to ``None``.
    """

    def __init__(self, n_rotations, cfg):
        """
        Args:
            - n_rotations (int): number of orientation bins for
                end-effectors pose
            - cfg: configuration object, stored as ``self.cfg``
        """
        super().__init__()
        utils.set_seed(0)

        self.device_type = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.crop_size = 64
        self.n_rotations = n_rotations  # 36
        self.cfg = cfg

        self.pix_size = PIXEL_SIZE
        self.in_shape = IN_SHAPE
        self.bounds = BOUNDS
        self.real_robot = False

        self._build_model()

    def _build_model(self):
        """Create ``pick_net`` and ``place_net``; the base class leaves both unset."""
        self.pick_net = None
        self.place_net = None

    def pick_forward(self, img, lang=None, softmax=False):
        """Return the output of ``pick_net`` for ``img`` (and optional ``lang``)."""
        return self.pick_net(img, lang, softmax=softmax)

    def place_forward(self, img, pick_pose, lang=None,
            softmax=False, img_pick=None):
        """Return the output of ``place_net`` for ``img`` given ``pick_pose``."""
        return self.place_net(img, pick_pose, lang, softmax=softmax, img_pick=img_pick)

    def forward(self, img, lang=None, pick_hmap=None, place_hmap=None):
        """Run an inference step given visual observations (H, W, 6).

        Args:
            img: observation tensor; index 3 of its last axis is read as the
                height map.
            lang: optional language input passed to both networks.
            pick_hmap: optional map applied to the pick scores before the
                argmax (scores are multiplied by it and 0.1 times the map
                is added). Presumably a 0/1 mask over the image plane --
                TODO confirm against callers.
            place_hmap: same as ``pick_hmap`` but for the place scores.

        Returns:
            A tuple ``(action, pick_conf, place_conf, tr_pick, tr_place)``.
            ``action`` is a dict with ``'pose0'``/``'pose1'`` (position and
            quaternion arrays) and ``'pick'``/``'place'`` (pixel indices).
            ``pick_conf``/``place_conf`` are the scores after any map was
            applied; ``tr_pick``/``tr_place`` are copies of the scores
            taken before the maps were applied.
        """
        # Pick model forward pass
        pick_conf = self.pick_forward(img.unsqueeze(0), lang, True)[0]
        pick_conf = pick_conf.detach().cpu()
        tr_pick = np.copy(pick_conf.numpy())
        if pick_hmap is not None:
            pick_conf = pick_conf * pick_hmap[..., None]
            pick_conf = pick_conf + 0.1 * pick_hmap[..., None]
        argmax = np.argmax(pick_conf.numpy())
        argmax = np.unravel_index(argmax, shape=pick_conf.shape)
        p0_pix = argmax[:2]
        p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

        # Place model forward pass
        place_conf = self.place_forward(
            img.unsqueeze(0),
            (torch.as_tensor([p0_pix[0]]), torch.as_tensor([p0_pix[1]])),
            lang,
            softmax=True
        )[0]
        place_conf = place_conf.detach().cpu()
        tr_place = np.copy(place_conf.numpy())
        if place_hmap is not None:
            place_conf = place_conf * place_hmap[..., None]
            place_conf = place_conf + 0.1 * place_hmap[..., None]

        argmax = np.argmax(place_conf.numpy())
        argmax = np.unravel_index(argmax, shape=place_conf.shape)
        p1_pix = argmax[:2]
        p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

        # The height map has to exist before the real-robot branch below
        # transposes it; reading it afterwards raised UnboundLocalError.
        height = img[..., 3].cpu().numpy()
        if self.real_robot:
            # Swap the two pixel coordinates and transpose the height map
            # so both use the same axis order.
            p0_pix = torch.tensor([p0_pix[1], p0_pix[0]])
            p1_pix = torch.tensor([p1_pix[1], p1_pix[0]])
            height = height.transpose(1, 0)

        # Pixels to end effector poses
        p0_xyz = utils.pix_to_xyz(p0_pix, height, self.bounds, self.pix_size)
        p1_xyz = utils.pix_to_xyz(p1_pix, height, self.bounds, self.pix_size)
        p0_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
        p1_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))
        return {
            'pose0': (np.asarray(p0_xyz), np.asarray(p0_xyzw)),
            'pose1': (np.asarray(p1_xyz), np.asarray(p1_xyzw)),
            'pick': p0_pix,
            'place': p1_pix
        }, pick_conf, place_conf, tr_pick, tr_place


class OriginalTransporterAgent(TransporterAgent):
    """Agent whose pick and place networks both use the 'plain_resnet' stream."""

    def __init__(self, n_rotations, cfg, pretrained=False):
        """Initialise the base agent, then attach the pick and place networks.

        Args:
            n_rotations: number of orientation bins given to the place network.
            cfg: configuration passed to the base class and to both networks.
            pretrained: passed unchanged to ``Attention`` and ``Transport``.
        """
        super().__init__(n_rotations, cfg)

        # Keyword arguments that both networks receive unchanged.
        common = dict(
            stream_fcn=('plain_resnet', None),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
            pretrained=pretrained
        )
        # The pick network always uses a single rotation bin.
        self.pick_net = Attention(n_rotations=1, **common)
        self.place_net = Transport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common
        )


class GCTransporterAgent(TransporterAgent):
    """Agent that picks from one image and places using a second image."""

    def __init__(self, n_rotations, cfg, pretrained=False):
        """Initialise the base agent, then attach 'plain_resnet' networks.

        Args:
            n_rotations: number of orientation bins given to the place network.
            cfg: configuration passed to the base class and to both networks.
            pretrained: passed unchanged to ``Attention`` and ``Transport``.
        """
        super().__init__(n_rotations, cfg)

        # Keyword arguments that both networks receive unchanged.
        common = dict(
            stream_fcn=('plain_resnet', None),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
            pretrained=pretrained
        )
        self.pick_net = Attention(n_rotations=1, **common)
        self.place_net = Transport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common
        )

    def forward(self, img_pick, img_place):
        """Run an inference step given visual observations (H, W, 6).

        Args:
            img_pick: observation fed to the pick network (and passed to the
                place network as ``img_pick``).
            img_place: observation fed to the place network.

        Returns:
            A tuple ``(action, pick_conf, place_conf, tr_pick, tr_place)``
            where ``action`` holds the two poses and the two pixel
            locations, and ``tr_pick``/``tr_place`` are NumPy copies of the
            score tensors ``pick_conf``/``place_conf``.
        """
        def peak(conf):
            # Index of the highest score: first two coordinates, plus the
            # third coordinate converted from a bin index to an angle.
            cell = np.unravel_index(np.argmax(conf.numpy()), shape=conf.shape)
            return cell[:2], cell[2] * (2 * np.pi / conf.shape[2])

        # Pick model forward pass
        pick_conf = self.pick_forward(img_pick.unsqueeze(0), None, True)[0]
        pick_conf = pick_conf.detach().cpu()
        tr_pick = np.copy(pick_conf.numpy())
        p0_pix, p0_theta = peak(pick_conf)

        # Place model forward pass, conditioned on the chosen pick pixel
        pick_pose = (torch.as_tensor([p0_pix[0]]), torch.as_tensor([p0_pix[1]]))
        place_out = self.place_forward(
            img_place.unsqueeze(0),
            pick_pose,
            lang=None,
            softmax=True,
            img_pick=img_pick.unsqueeze(0)
        )
        place_conf = place_out[0].detach().cpu()
        tr_place = np.copy(place_conf.numpy())
        p1_pix, p1_theta = peak(place_conf)

        # Height maps come from index 3 of the last axis of each image.
        height_pick = img_pick[..., 3].cpu().numpy()
        height_place = img_place[..., 3].cpu().numpy()
        if self.real_robot:
            # Swap pixel coordinates and transpose the height maps together.
            p0_pix = torch.tensor([p0_pix[1], p0_pix[0]])
            p1_pix = torch.tensor([p1_pix[1], p1_pix[0]])
            height_pick = height_pick.transpose(1, 0)
            height_place = height_place.transpose(1, 0)

        # Pixels to end effector poses
        p0_xyz = utils.pix_to_xyz(p0_pix, height_pick, self.bounds, self.pix_size)
        p1_xyz = utils.pix_to_xyz(p1_pix, height_place, self.bounds, self.pix_size)
        p0_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
        p1_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))

        action = {
            'pose0': (np.asarray(p0_xyz), np.asarray(p0_xyzw)),
            'pose1': (np.asarray(p1_xyz), np.asarray(p1_xyzw)),
            'pick': p0_pix,
            'place': p1_pix,
        }
        return action, pick_conf, place_conf, tr_pick, tr_place



class ClipUNetTransporterAgent(TransporterAgent):
    """Single-stream agent using the 'clip_unet' stream for pick and place."""

    def __init__(self, name, cfg):
        """Pass ``name`` (as ``n_rotations``) and ``cfg`` to the base class."""
        super().__init__(name, cfg)

    def _build_model(self):
        """Create the pick (``Attention``) and place (``Transport``) networks."""
        # Keyword arguments that both networks receive unchanged.
        common = dict(
            stream_fcn=('clip_unet', None),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type
        )
        self.pick_net = Attention(n_rotations=1, **common)
        self.place_net = Transport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common
        )


class TwoStreamClipUNetTransporterAgent(TransporterAgent):
    """Two-stream agent combining 'plain_resnet' and 'clip_unet' streams."""

    def __init__(self, name, cfg):
        """Pass ``name`` (as ``n_rotations``) and ``cfg`` to the base class."""
        super().__init__(name, cfg)

    def _build_model(self):
        """Create the two-stream pick and place networks."""
        streams = ('plain_resnet', 'clip_unet')
        # Keyword arguments that both networks receive unchanged.
        common = dict(
            stream_fcn=streams,
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type
        )
        self.pick_net = TwoStreamAttention(n_rotations=1, **common)
        self.place_net = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common
        )


class TwoStreamClipUNetLatTransporterAgent(TransporterAgent):
    """Two-stream agent using the 'plain_resnet_lat' and 'clip_unet_lat' streams."""

    def __init__(self, name, cfg):
        """Pass ``name`` (as ``n_rotations``) and ``cfg`` to the base class."""
        super().__init__(name, cfg)

    def _build_model(self):
        """Create the ``*Lat`` variants of the two-stream pick and place networks."""
        streams = ('plain_resnet_lat', 'clip_unet_lat')
        # Keyword arguments that both networks receive unchanged.
        common = dict(
            stream_fcn=streams,
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type
        )
        self.pick_net = TwoStreamAttentionLat(n_rotations=1, **common)
        self.place_net = TwoStreamTransportLat(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common
        )


class TwoStreamClipWithoutSkipsTransporterAgent(TransporterAgent):
    """Two-stream agent combining 'plain_resnet' and 'clip_woskip' streams."""

    def __init__(self, name, cfg):
        """Pass ``name`` (as ``n_rotations``) and ``cfg`` to the base class."""
        super().__init__(name, cfg)

    def _build_model(self):
        """Create the two-stream pick and place networks."""
        # TODO: lateral version
        streams = ('plain_resnet', 'clip_woskip')
        # Keyword arguments that both networks receive unchanged.
        common = dict(
            stream_fcn=streams,
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type
        )
        self.pick_net = TwoStreamAttention(n_rotations=1, **common)
        self.place_net = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common
        )


class TwoStreamRN50BertUNetTransporterAgent(TransporterAgent):
    """Two-stream agent combining 'plain_resnet' and 'rn50_bert_unet' streams."""

    def __init__(self, name, cfg):
        """Pass ``name`` (as ``n_rotations``) and ``cfg`` to the base class."""
        super().__init__(name, cfg)

    def _build_model(self):
        """Create the two-stream pick and place networks."""
        # TODO: lateral version
        streams = ('plain_resnet', 'rn50_bert_unet')
        # Keyword arguments that both networks receive unchanged.
        common = dict(
            stream_fcn=streams,
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type
        )
        self.pick_net = TwoStreamAttention(n_rotations=1, **common)
        self.place_net = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common
        )


class TwoStreamClipLingUNetTransporterAgent(TransporterAgent):
    """Two-stream language-fusion agent: 'plain_resnet' plus 'clip_lingunet'."""

    def __init__(self, n_rotations, cfg):
        """Pass ``n_rotations`` and ``cfg`` straight to the base class."""
        super().__init__(n_rotations, cfg)

    def _build_model(self):
        """Create the language-fusion pick and place networks."""
        streams = ('plain_resnet', 'clip_lingunet')
        # Keyword arguments that both networks receive unchanged.
        common = dict(
            stream_fcn=streams,
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type
        )
        self.pick_net = TwoStreamAttentionLangFusion(n_rotations=1, **common)
        self.place_net = TwoStreamTransportLangFusion(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common
        )


class TwoStreamClipFilmLingUNetLatTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Two-stream language-fusion agent: 'plain_resnet_lat' plus 'clip_film_lingunet_lat'."""

    def __init__(self, name, cfg, train_ds=None, test_ds=None):
        """Initialise the agent.

        Args:
            name: passed to the parent constructor as ``n_rotations``.
            cfg: configuration passed to the parent constructor.
            train_ds: accepted for backward compatibility; unused.
            test_ds: accepted for backward compatibility; unused.
        """
        # The parent constructor takes only (n_rotations, cfg); forwarding
        # the dataset arguments raised TypeError on every instantiation.
        super().__init__(name, cfg)

    def _build_model(self):
        """Create the ``*Lat`` language-fusion pick and place networks."""
        stream_one_fcn = 'plain_resnet_lat'
        stream_two_fcn = 'clip_film_lingunet_lat'
        self.pick_net = TwoStreamAttentionLangFusionLat(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.place_net = TwoStreamTransportLangFusionLat(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )


class TwoStreamClipLingUNetLatTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Two-stream language-fusion agent: 'plain_resnet_lat' plus 'clip_lingunet_lat'."""

    def __init__(self, n_rotations, cfg):
        """Pass ``n_rotations`` and ``cfg`` straight to the parent class."""
        super().__init__(n_rotations, cfg)

    def _build_model(self):
        """Create the ``*Lat`` language-fusion pick and place networks."""
        streams = ('plain_resnet_lat', 'clip_lingunet_lat')
        # Keyword arguments that both networks receive unchanged.
        common = dict(
            stream_fcn=streams,
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type
        )
        self.pick_net = TwoStreamAttentionLangFusionLat(n_rotations=1, **common)
        self.place_net = TwoStreamTransportLangFusionLat(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common
        )


class TwoStreamRN50BertLingUNetTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Two-stream language-fusion agent: 'plain_resnet' plus 'rn50_bert_lingunet'."""

    def __init__(self, name, cfg, train_ds=None, test_ds=None):
        """Initialise the agent.

        Args:
            name: passed to the parent constructor as ``n_rotations``.
            cfg: configuration passed to the parent constructor.
            train_ds: accepted for backward compatibility; unused.
            test_ds: accepted for backward compatibility; unused.
        """
        # The parent constructor takes only (n_rotations, cfg); forwarding
        # the dataset arguments raised TypeError on every instantiation.
        super().__init__(name, cfg)

    def _build_model(self):
        """Create the language-fusion pick and place networks."""
        stream_one_fcn = 'plain_resnet'
        stream_two_fcn = 'rn50_bert_lingunet'
        self.pick_net = TwoStreamAttentionLangFusion(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.place_net = TwoStreamTransportLangFusion(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )


class TwoStreamUntrainedRN50BertLingUNetTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Two-stream language-fusion agent: 'plain_resnet' plus 'untrained_rn50_bert_lingunet'."""

    def __init__(self, name, cfg, train_ds=None, test_ds=None):
        """Initialise the agent.

        Args:
            name: passed to the parent constructor as ``n_rotations``.
            cfg: configuration passed to the parent constructor.
            train_ds: accepted for backward compatibility; unused.
            test_ds: accepted for backward compatibility; unused.
        """
        # The parent constructor takes only (n_rotations, cfg); forwarding
        # the dataset arguments raised TypeError on every instantiation.
        super().__init__(name, cfg)

    def _build_model(self):
        """Create the language-fusion pick and place networks."""
        stream_one_fcn = 'plain_resnet'
        stream_two_fcn = 'untrained_rn50_bert_lingunet'
        self.pick_net = TwoStreamAttentionLangFusion(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.place_net = TwoStreamTransportLangFusion(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )


class TwoStreamRN50BertLingUNetLatTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Two-stream language-fusion agent: 'plain_resnet_lat' plus 'rn50_bert_lingunet_lat'."""

    def __init__(self, name, cfg, train_ds=None, test_ds=None):
        """Initialise the agent.

        Args:
            name: passed to the parent constructor as ``n_rotations``.
            cfg: configuration passed to the parent constructor.
            train_ds: accepted for backward compatibility; unused.
            test_ds: accepted for backward compatibility; unused.
        """
        # The parent constructor takes only (n_rotations, cfg); forwarding
        # the dataset arguments raised TypeError on every instantiation.
        super().__init__(name, cfg)

    def _build_model(self):
        """Create the ``*Lat`` language-fusion pick and place networks."""
        stream_one_fcn = 'plain_resnet_lat'
        stream_two_fcn = 'rn50_bert_lingunet_lat'
        self.pick_net = TwoStreamAttentionLangFusionLat(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.place_net = TwoStreamTransportLangFusionLat(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )


class OriginalTransporterLangFusionAgent(TwoStreamClipLingUNetTransporterAgent):
    """One-stream language-fusion agent using the 'plain_resnet_lang' stream."""

    def __init__(self, name, cfg, train_ds=None, test_ds=None):
        """Initialise the agent.

        Args:
            name: passed to the parent constructor as ``n_rotations``.
            cfg: configuration passed to the parent constructor.
            train_ds: accepted for backward compatibility; unused.
            test_ds: accepted for backward compatibility; unused.
        """
        # The parent constructor takes only (n_rotations, cfg); forwarding
        # the dataset arguments raised TypeError on every instantiation.
        super().__init__(name, cfg)

    def _build_model(self):
        """Create the one-stream language-fusion pick and place networks."""
        stream_fcn = 'plain_resnet_lang'
        self.pick_net = OneStreamAttentionLangFusion(
            stream_fcn=(stream_fcn, None),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.place_net = OneStreamTransportLangFusion(
            stream_fcn=(stream_fcn, None),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )



class ClipLingUNetTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """One-stream language-fusion agent using the 'clip_lingunet' stream."""

    def __init__(self, name, cfg, train_ds=None, test_ds=None):
        """Initialise the agent.

        Args:
            name: passed to the parent constructor as ``n_rotations``.
            cfg: configuration passed to the parent constructor.
            train_ds: accepted for backward compatibility; unused.
            test_ds: accepted for backward compatibility; unused.
        """
        # The parent constructor takes only (n_rotations, cfg); forwarding
        # the dataset arguments raised TypeError on every instantiation.
        super().__init__(name, cfg)

    def _build_model(self):
        """Create the one-stream language-fusion pick and place networks."""
        stream_fcn = 'clip_lingunet'
        self.pick_net = OneStreamAttentionLangFusion(
            stream_fcn=(stream_fcn, None),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.place_net = OneStreamTransportLangFusion(
            stream_fcn=(stream_fcn, None),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )


class ImageGoalTransporterAgent(OriginalTransporterAgent):
    """Transporter agent conditioned on a goal image.

    The pick and place networks are ``AttentionImageGoal`` and
    ``TransportImageGoal``; both take the current image and a goal image.

    NOTE(review): the training and validation methods call
    ``attn_criterion``, ``transport_criterion``, ``log`` and
    ``check_save_iteration`` and read ``trainer``, ``total_steps`` and
    ``val_repeats``. None of these are defined by this class or by its
    base classes in this file -- confirm where they are expected to come
    from before relying on those methods.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        """Initialise the agent.

        Args:
            name: passed to ``TransporterAgent.__init__`` as ``n_rotations``.
            cfg: configuration object.
            train_ds: training dataset, stored as ``self.train_ds``.
            test_ds: dataset stored as ``self.test_ds``; ``act`` uses its
                ``get_image`` method.
        """
        # Initialise through TransporterAgent directly: the immediate parent
        # constructor does not accept the dataset arguments (TypeError) and
        # would also replace the goal-conditioned networks created by
        # _build_model() with its own Attention/Transport pair.
        TransporterAgent.__init__(self, name, cfg)
        self.train_ds = train_ds
        self.test_ds = test_ds

    def _build_model(self):
        """Create the goal-conditioned pick and place networks."""
        stream_fcn = 'plain_resnet'
        self.pick_net = AttentionImageGoal(
            stream_fcn=(stream_fcn, None),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.place_net = TransportImageGoal(
            stream_fcn=(stream_fcn, None),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )

    def attn_forward(self, inp, softmax=True):
        """Run the pick network on ``inp['inp_img']`` and ``inp['goal_img']``."""
        inp_img = inp['inp_img']
        goal_img = inp['goal_img']

        # The networks live in pick_net/place_net; `self.pick` never existed.
        out = self.pick_net.forward(inp_img, goal_img, softmax=softmax)
        return out

    def attn_training_step(self, frame, goal, backprop=True, compute_err=False):
        """Run the pick network on one sample and return ``attn_criterion``'s result.

        Callers in this class unpack the result as ``(loss, err)``.
        """
        inp_img = frame['img']
        goal_img = goal['img']
        p0, p0_theta = frame['p0'], frame['p0_theta']

        inp = {'inp_img': inp_img, 'goal_img': goal_img}
        out = self.attn_forward(inp, softmax=False)
        return self.attn_criterion(backprop, compute_err, inp, out, p0, p0_theta)

    def trans_forward(self, inp, softmax=True):
        """Run the place network on ``inp['inp_img']``, ``inp['goal_img']`` and ``inp['p0']``."""
        inp_img = inp['inp_img']
        goal_img = inp['goal_img']
        p0 = inp['p0']

        out = self.place_net.forward(inp_img, goal_img, p0, softmax=softmax)
        return out

    def transport_training_step(self, frame, goal, backprop=True, compute_err=False):
        """Run the place network on one sample and return ``(loss, err)``."""
        inp_img = frame['img']
        goal_img = goal['img']
        p0 = frame['p0']
        p1, p1_theta = frame['p1'], frame['p1_theta']

        inp = {'inp_img': inp_img, 'goal_img': goal_img, 'p0': p0}
        out = self.trans_forward(inp, softmax=False)
        # transport_criterion's result is unpacked as (err, loss) and
        # returned swapped so both *_training_step methods agree.
        err, loss = self.transport_criterion(backprop, compute_err, inp, out, p0, p1, p1_theta)
        return loss, err

    def training_step(self, batch, batch_idx):
        """Compute the summed pick and place loss for one ``(frame, goal)`` batch.

        Returns:
            ``dict(loss=total_loss)``.
        """
        self.pick_net.train()
        self.place_net.train()

        frame, goal = batch

        # Get training losses.
        step = self.total_steps + 1
        loss0, err0 = self.attn_training_step(frame, goal)
        if isinstance(self.place_net, Attention):
            loss1, err1 = self.attn_training_step(frame, goal)
        else:
            loss1, err1 = self.transport_training_step(frame, goal)
        total_loss = loss0 + loss1
        self.log('tr/attn/loss', loss0)
        self.log('tr/trans/loss', loss1)
        self.log('tr/loss', total_loss)
        self.total_steps = step

        self.trainer.train_loop.running_loss.append(total_loss)

        self.check_save_iteration()

        return dict(
            loss=total_loss,
        )

    def validation_step(self, batch, batch_idx):
        """Average the pick and place losses over ``self.val_repeats`` passes.

        The reported error values come from the last pass only.
        """
        self.pick_net.eval()
        self.place_net.eval()

        loss0, loss1 = 0, 0
        for _ in range(self.val_repeats):
            frame, goal = batch
            l0, err0 = self.attn_training_step(frame, goal, backprop=False, compute_err=True)
            loss0 += l0
            if isinstance(self.place_net, Attention):
                l1, err1 = self.attn_training_step(frame, goal, backprop=False, compute_err=True)
                loss1 += l1
            else:
                l1, err1 = self.transport_training_step(frame, goal, backprop=False, compute_err=True)
                loss1 += l1
        loss0 /= self.val_repeats
        loss1 /= self.val_repeats
        val_total_loss = loss0 + loss1

        self.trainer.evaluation_loop.trainer.train_loop.running_loss.append(val_total_loss)

        return dict(
            val_loss=val_total_loss,
            val_loss0=loss0,
            val_loss1=loss1,
            val_attn_dist_err=err0['dist'],
            val_attn_theta_err=err0['theta'],
            val_trans_dist_err=err1['dist'],
            val_trans_theta_err=err1['theta'],
        )

    def act(self, obs, info=None, goal=None):  # pylint: disable=unused-argument
        """Run inference and return best action given visual observations.

        Args:
            obs: observation converted to an image with ``test_ds.get_image``.
            info: unused.
            goal: indexable whose first element is the goal observation; it
                is indexed unconditionally, so the ``None`` default fails.

        Returns:
            Dict with ``'pose0'``/``'pose1'`` (position and quaternion
            arrays) and ``'pick'``/``'place'`` (pixel indices).
        """
        # Get heightmap from RGB-D images.
        img = self.test_ds.get_image(obs)
        goal_img = self.test_ds.get_image(goal[0])

        # Attention model forward pass.
        pick_conf = self.pick_net.forward(img, goal_img)
        pick_conf = pick_conf.detach().cpu().numpy()
        argmax = np.argmax(pick_conf)
        argmax = np.unravel_index(argmax, shape=pick_conf.shape)
        p0_pix = argmax[:2]
        p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

        # Transport model forward pass.
        place_conf = self.place_net.forward(img, goal_img, p0_pix)
        # Move the first axis last so the rotation index is argmax[2].
        place_conf = place_conf.permute(1, 2, 0)
        place_conf = place_conf.detach().cpu().numpy()
        argmax = np.argmax(place_conf)
        argmax = np.unravel_index(argmax, shape=place_conf.shape)
        p1_pix = argmax[:2]
        p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

        # Pixels to end effector poses.
        hmap = img[:, :, 3]
        p0_xyz = utils.pix_to_xyz(p0_pix, hmap, self.bounds, self.pix_size)
        p1_xyz = utils.pix_to_xyz(p1_pix, hmap, self.bounds, self.pix_size)
        p0_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
        p1_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))

        return {
            'pose0': (np.asarray(p0_xyz), np.asarray(p0_xyzw)),
            'pose1': (np.asarray(p1_xyz), np.asarray(p1_xyzw)),
            'pick': p0_pix,
            'place': p1_pix,
        }
